#include "vector_query_process.h"
#include "vector.pb.h"
#include "vector_processor.h"
#include "../key_format.h"
#include "../sort_operator/distance_sort_operator.h"

/**
 * Constructs a vector-query processor from the request's JSON value.
 * All storage of `value` is handled by the QueryProcess base constructor.
 */
VectorQueryProcess::VectorQueryProcess(const Json::Value& value) : QueryProcess(value)
{
}

// vector_data_ and vec_dis_map_ release their own storage when destroyed,
// so explicitly clearing them first is unnecessary (Rule of Zero).
VectorQueryProcess::~VectorQueryProcess() = default;

int VectorQueryProcess::ParseContent(){
    return ParseContent(ORKEY);
}

int VectorQueryProcess::ParseContent(int logic_type){
    Json::Value::Members member = parse_value_.getMemberNames();
    Json::Value::Members::iterator iter = member.begin();
    std::string field_name;
    Json::Value field_value;
    uint32_t index_type_id = 0;
    bool has_type_id = false;
    for(; iter != member.end(); ++iter){
        Json::Value vec_value = parse_value_[*iter];
        if (INDEX_TYPE_ID == (*iter)){
            if (vec_value.isUInt()){
                index_type_id = vec_value.asUInt();
                has_type_id = true;
            } else {
                log_error("index_type_id should be int.");
                return -RT_PARSE_CONTENT_ERROR;
            }
        } else {
            field_name = (*iter);
            field_value = parse_value_[field_name];
        }
    }
    if(field_name.empty()) {
        log_error("VectorQueryProcess error, field_name is null");
        return -RT_PARSE_CONTENT_ERROR;
    }
    uint32_t segment_tag = 0;

    field_info_.query_type = E_INDEX_READ_TERM;

    uint32_t ui_ret = DBManager::Instance()->GetWordField(segment_tag, component_->Appid(), field_name, field_info_);
    if(ui_ret == 0){
        log_error("field_name:%s error, not in the app_field_define", field_name.c_str());
        return -RT_PARSE_CONTENT_ERROR;
    }
    if(field_value.isArray()){
        for(int i = 0; i < (int)field_value.size(); i++){
            if(field_value[i].isDouble()){
                vector_data_.push_back(field_value[i].asDouble());
            }
        }
    }
    if(index_type_id >= field_info_.index_type_size){
        log_error("index_type_id: %d should be smaller than index_type_size: %d", index_type_id, field_info_.index_type_size);
        return -RT_PARSE_CONTENT_ERROR;
    }
    if((uint32_t)vector_data_.size() != field_info_.dim){
        log_error("vector_data_ size: %d should be equal to dim: %d", (uint32_t)vector_data_.size(), field_info_.dim);
        return -RT_PARSE_CONTENT_ERROR;
    }

    // send vector data to vector service
    VectorReq vec_req;
    vec_req.set_appid(component_->Appid());
    vec_req.set_field_id(field_info_.field);
    vec_req.set_vector_num(1);
    if(has_type_id){
        vec_req.set_index_typeid(index_type_id);  // parse from request
    }
    vec_req.set_topk(component_->PageIndex() * component_->PageSize()); // top k depends on page_size and page_index
    for (auto iter = vector_data_.begin(); iter != vector_data_.end(); iter++){
        vec_req.add_vector_data(*iter);
    }
    string sender;
    bool b_ret = vec_req.SerializeToString(&sender);
    if(false == b_ret){
        SetErrorContext(__LINE__, RT_SERIALIZE_TO_ARRAY_ERROR ,"SerializeToArray got an error");
        return RT_SERIALIZE_TO_ARRAY_ERROR;
    }
    VectorRsp vector_rsp;
    int ret = VectorProcessor::Instance()->SendToVectorService((char *)sender.c_str(), sender.length(), SERVICE_VECTOR_QUERY, vector_rsp);
    if(0 != ret){
        SetErrorContext(__LINE__, ret ,"SendToVectorService got an error");
        return RT_GET_DOC_ERR;
    }
    if(vector_rsp.code() != 0){
        SetErrorContext(__LINE__, vector_rsp.code() , vector_rsp.info());
        return RT_GET_DOC_ERR;
    }
    if(vector_rsp.vector_id_size() == 0){
        SetErrorContext(__LINE__, vector_rsp.code() ,"vector_id_size is 0");
        return RT_GET_DOC_ERR;
    }
    if(vector_rsp.vector_id_size() != vector_rsp.vector_dis_size()){
        SetErrorContext(__LINE__, vector_rsp.code() ,"vector_id_size should be equal to vector_dis_size");
        return RT_GET_DOC_ERR;
    }
    std::vector<FieldInfo> field_info_vec;
    for(int i = 0; i < vector_rsp.vector_id_size(); i++){
        // query doc id by vector id
        uint64_t vector_id = vector_rsp.vector_id(i);
        if(-1 == (int64_t)vector_id){
            continue;
        }
        FieldInfo field_info;
        field_info = field_info_;
        field_info.word = to_string(vector_id);
        double distance = vector_rsp.vector_dis(i);
        vec_dis_map_.insert(make_pair(field_info.word, distance));
        field_info_vec.push_back(field_info);
    }
    component_->AddToFieldList(logic_type, field_info_vec);
    return 0;
}

/**
 * Resolves valid documents for the ORKEY field list produced by ParseContent().
 * @return -RT_GET_FIELD_ERROR when the ORKEY field list is empty, otherwise
 *         the result of GetValidDoc(int, const std::vector<FieldInfo>&).
 */
int VectorQueryProcess::GetValidDoc()
{
    // Look the list up once; the bound reference covers both the check and the call.
    const auto& or_field_list = component_->GetFieldList(ORKEY);
    if (or_field_list.empty()) {
        return -RT_GET_FIELD_ERROR;
    }
    return GetValidDoc(ORKEY, or_field_list[FIRST_TEST_INDEX]);
}

int VectorQueryProcess::GetValidDoc(int logic_type, const std::vector<FieldInfo>& keys){
    log_debug("vector query GetValidDoc beginning...");
    if (0 == keys[FIRST_SPLIT_WORD_INDEX].index_tag){
        return -RT_GET_FIELD_ERROR;
    }

    std::vector<IndexInfo> index_info_vet;
    std::map<std::string, std::string> doc_key_map;
    int iret = ValidDocFilter::Instance()->TextInvertIndexSearch(keys, index_info_vet, doc_key_map);
    if (iret != 0) { return iret; }
    std::vector<IndexInfo> o_valid_index_infos;
    for(auto iter = index_info_vet.begin(); iter != index_info_vet.end(); iter++){
        if(doc_key_map.find(iter->doc_id) != doc_key_map.end()){
            std::string vec_id = doc_key_map.at(iter->doc_id);
            if(vec_dis_map_.find(vec_id) != vec_dis_map_.end()){
                double distance = vec_dis_map_.at(vec_id);
                iter->distance = distance;
                o_valid_index_infos.push_back(*iter);
            }
        }
    }
    ResultContext::Instance()->SetIndexInfos(logic_type , o_valid_index_infos);
    return 0;
}

/**
 * Installs a DistanceSortOperator and obtains the scored doc-id set for the
 * component's sort type.
 * @return always 0.
 */
int VectorQueryProcess::GetScore()
{
    log_debug("vector query GetScore beginning...");
    // NOTE(review): raw owning `new` with no matching delete in this file;
    // presumably QueryProcess releases sort_operator_base_ — confirm, and confirm
    // GetScore runs at most once per instance (otherwise the old object leaks).
    sort_operator_base_ = new DistanceSortOperator(component_ , doc_manager_);
    const uint32_t requested_sort = static_cast<uint32_t>(component_->SortType());
    p_scoredocid_set_ = sort_operator_base_->GetSortOperator(requested_sort);
    return 0;
}

/**
 * Orders the scored documents according to component_->SortType().
 *
 * - Field sort (SORT_FIELD_DESC / SORT_FIELD_ASC) with an empty
 *   p_scoredocid_set_ falls back to SortByCOrderOp().
 * - SORT_FIELD_DESC or DONT_SORT: descending order.
 * - Anything else: ascending by default, so the closest distance comes first.
 *
 * @param i_sequence  forwarded to DescSort/AscSort.
 * @param i_rank      forwarded to every sort routine.
 */
void VectorQueryProcess::SortScore(int& i_sequence , int& i_rank)
{
    log_debug("vector query SortScore beginning...");

    const auto sort_type = component_->SortType();
    const bool is_field_sort = (SORT_FIELD_DESC == sort_type) || (SORT_FIELD_ASC == sort_type);

    // p_scoredocid_set_ is only dereferenced for field sorts (short-circuit).
    if (is_field_sort && p_scoredocid_set_->empty()) {
        SortByCOrderOp(i_rank);
        return;
    }
    // Descending order and "don't sort" are both handled by DescSort.
    if (SORT_FIELD_DESC == sort_type || DONT_SORT == sort_type) {
        DescSort(i_sequence , i_rank);
        return;
    }
    // Unspecified: ascending by default, nearest distance first.
    AscSort(i_sequence , i_rank);
}